{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install -r ../requirements.txt -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "felafax package imported successfully\n"
     ]
    }
   ],
   "source": [
    "import importlib\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Add the current directory and its parent to the Python path.\n",
    "# This allows importing modules from these directories.\n",
    "sys.path.append(os.path.abspath(os.getcwd()))\n",
    "sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import llama3_jax\n",
    "from llama3_jax.trainer_engine import setup\n",
    "setup.setup_environment(base_dir=\"/mnt/persistent-disk/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reloaded all felafax modules.\n"
     ]
    }
   ],
   "source": [
    "from llama3_jax.trainer_engine import (automodel_lib, checkpoint_lib,\n",
    "                                       convert_lib, llama_config, jax_utils, trainer_lib,\n",
    "                                       utils)\n",
    "setup.reload_modules(\"llama3_jax\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union\n",
    "\n",
    "import chex\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import optax\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import default_data_collator\n",
    "from huggingface_hub import snapshot_download\n",
    "import shutil\n",
    "from datetime import datetime\n",
    "import gzip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the model to fine-tune; it is passed to AutoJAXModelForCausalLM.from_pretrained below.\n",
    "MODEL_NAME = \"llama-3.1-8B-Instruct-JAX\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading model llama-3.1-8B-Instruct-JAX...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fetching 8 files: 100%|██████████| 8/8 [08:38<00:00, 64.77s/it] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llama-3.1-8B-Instruct-JAX was downloaded to /mnt/persistent-disk/hf/models--felafax--llama-3.1-8B-Instruct-JAX/snapshots/12d9565c6c550893fd3c0ab62c2b91b16acf1218/llama-3.1-8B-Instruct-JAX.flax.\n"
     ]
    }
   ],
   "source": [
    "model_path, model, model_configurator, tokenizer = (\n",
    "    automodel_lib.AutoJAXModelForCausalLM.from_pretrained(MODEL_NAME))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None, seed=None):\n",
    "    \"\"\"Build shuffled train/test DataLoaders over the Alpaca-cleaned dataset.\n",
    "\n",
    "    Every row is rendered with the Alpaca prompt template, terminated with the\n",
    "    tokenizer's EOS token, tokenized to exactly ``seq_length + 1`` tokens\n",
    "    (truncated or padded) and split into next-token-prediction pairs.\n",
    "\n",
    "    Args:\n",
    "        tokenizer: Tokenizer used to encode the prompts; its ``eos_token`` is\n",
    "            appended to every example.\n",
    "        batch_size: Number of examples per batch.\n",
    "        seq_length: Length, in tokens, of the input and target sequences.\n",
    "        max_examples: If truthy, only the first ``max_examples`` rows are used\n",
    "            (clamped to the size of the dataset).\n",
    "        seed: Optional seed for the train/test split. ``None`` (the default)\n",
    "            leaves the split unseeded.\n",
    "\n",
    "    Returns:\n",
    "        ``(train_dataloader, test_dataloader)``. Each batch is a dict with the\n",
    "        keys ``input_tokens``, ``target_tokens`` and ``loss_masks``, each a JAX\n",
    "        array of shape ``(batch, seq_length)``.\n",
    "    \"\"\"\n",
    "    # The template is assembled from adjacent string literals instead of an\n",
    "    # indented triple-quoted block, so the indentation of this function body\n",
    "    # does not end up inside the prompt text.\n",
    "    alpaca_prompt = (\n",
    "        \"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\\n\"\n",
    "        \"\\n\"\n",
    "        \"### Instruction: {}\\n\"\n",
    "        \"\\n\"\n",
    "        \"### Input: {}\\n\"\n",
    "        \"\\n\"\n",
    "        \"### Response: {}\"\n",
    "    )\n",
    "\n",
    "    eos_token = tokenizer.eos_token\n",
    "\n",
    "    def _format_prompts(examples):\n",
    "        \"\"\"Render each (instruction, input, output) row into one training string.\"\"\"\n",
    "        rows = zip(examples[\"instruction\"], examples[\"input\"], examples[\"output\"])\n",
    "        # `context` is used instead of `input` to avoid shadowing the builtin.\n",
    "        texts = [\n",
    "            alpaca_prompt.format(instruction, context, response) + eos_token\n",
    "            for instruction, context, response in rows\n",
    "        ]\n",
    "        return {\"text\": texts}\n",
    "\n",
    "    def _tokenize(examples):\n",
    "        \"\"\"Tokenize to ``seq_length + 1`` tokens and shift by one position.\"\"\"\n",
    "        tokenized = tokenizer(\n",
    "            examples[\"text\"],\n",
    "            truncation=True,\n",
    "            padding=\"max_length\",\n",
    "            max_length=seq_length + 1,\n",
    "        )\n",
    "        return {\n",
    "            # Inputs drop the last token, targets drop the first one.\n",
    "            'input_tokens': [ids[:-1] for ids in tokenized['input_ids']],\n",
    "            'target_tokens': [ids[1:] for ids in tokenized['input_ids']],\n",
    "            # The attention mask is shifted the same way as the targets so the\n",
    "            # two stay aligned position by position.\n",
    "            'loss_masks': [mask[1:] for mask in tokenized['attention_mask']],\n",
    "        }\n",
    "\n",
    "    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:\n",
    "        \"\"\"Collate with ``default_data_collator``, then turn torch tensors into JAX arrays.\"\"\"\n",
    "        collated = default_data_collator(batch)\n",
    "        return {\n",
    "            key: jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value\n",
    "            for key, value in collated.items()\n",
    "        }\n",
    "\n",
    "    # Load and preprocess the dataset\n",
    "    dataset = load_dataset(\"yahma/alpaca-cleaned\", split=\"train\")\n",
    "    if max_examples:\n",
    "        # Clamp the limit so a value larger than the dataset does not make\n",
    "        # select() fail on out-of-range indices.\n",
    "        dataset = dataset.select(range(min(max_examples, len(dataset))))\n",
    "    dataset = dataset.map(_format_prompts, batched=True)\n",
    "\n",
    "    # Create train and test dataset.\n",
    "    ds = dataset.train_test_split(test_size=0.15, seed=seed)\n",
    "    for split in ('train', 'test'):\n",
    "        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)\n",
    "\n",
    "    # Create DataLoaders\n",
    "    dataloader_args = dict(shuffle=True, batch_size=batch_size, collate_fn=_custom_collate_fn)\n",
    "    train_dataloader = torch.utils.data.DataLoader(ds['train'], **dataloader_args)\n",
    "    test_dataloader = torch.utils.data.DataLoader(ds['test'], **dataloader_args)\n",
    "\n",
    "    return train_dataloader, test_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading readme: 100%|██████████| 11.6k/11.6k [00:00<00:00, 23.5MB/s]\n",
      "Downloading data: 100%|██████████| 44.3M/44.3M [00:00<00:00, 118MB/s] \n",
      "Generating train split: 51760 examples [00:00, 150940.12 examples/s]\n",
      "Map: 100%|██████████| 32/32 [00:00<00:00, 7448.26 examples/s]\n",
      "Map: 100%|██████████| 27/27 [00:00<00:00, 1405.09 examples/s]\n",
      "Map: 100%|██████████| 5/5 [00:00<00:00, 945.26 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input tokens shape: (1, 32)\n",
      "Target mask shape: (1, 32)\n"
     ]
    }
   ],
   "source": [
    "def test_dataset_pipeline(tokenizer):\n",
    "    \"\"\"Print shapes of first batch to verify dataset pipeline.\"\"\"\n",
    "    train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, seq_length=32, max_examples=32)\n",
    "    batch = next(iter(train_loader))\n",
    "    print(\"Input tokens shape:\", batch['input_tokens'].shape)\n",
    "    print(\"Target mask shape:\", batch['target_tokens'].shape)\n",
    "test_dataset_pipeline(tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "@chex.dataclass(frozen=True)\n",
    "class TrainerConfig:\n",
    "    \"\"\"Immutable bundle of hyperparameters for the fine-tuning run.\"\"\"\n",
    "\n",
    "    # Step size handed to the optimizer created below.\n",
    "    learning_rate: float = 1e-4\n",
    "    num_epochs: int = 1\n",
    "    # Upper bound on the number of training steps; a falsy value means no cap\n",
    "    # in the step summary computed later in this notebook.\n",
    "    max_steps: Optional[int] = 100\n",
    "    # Forwarded to get_dataset() as batch_size / seq_length.\n",
    "    batch_size: int = 32\n",
    "    seq_length: int = 64\n",
    "    # Forwarded to get_dataset() as max_examples; None uses the whole dataset.\n",
    "    dataset_size_limit: Optional[int] = None\n",
    "    # The three settings below are read by the trainer, not by this notebook;\n",
    "    # presumably they control logging and evaluation cadence (confirm in trainer_lib).\n",
    "    print_every_n_steps: int = 5\n",
    "    eval_every_n_steps: int = 1000\n",
    "    max_eval_steps: Optional[int] = 1\n",
    "\n",
    "\n",
    "trainer_config = TrainerConfig()\n",
    "# Plain SGD with the configured learning rate.\n",
    "optimizer = optax.sgd(learning_rate=trainer_config.learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 51760/51760 [00:00<00:00, 95472.40 examples/s] \n",
      "Map: 100%|██████████| 43996/43996 [00:21<00:00, 2006.26 examples/s]\n",
      "Map: 100%|██████████| 7764/7764 [00:03<00:00, 2052.44 examples/s]\n"
     ]
    }
   ],
   "source": [
    "# Prepare dataset\n",
    "train_dataloader, val_dataloader = get_dataset(\n",
    "    tokenizer=tokenizer,\n",
    "    batch_size=trainer_config.batch_size,\n",
    "    seq_length=trainer_config.seq_length,\n",
    "    max_examples=trainer_config.dataset_size_limit,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Configuration Summary:\n",
      "Total samples: 43996\n",
      "Batch size: 32\n",
      "Number of epochs: 1\n",
      "Steps per epoch: 1375\n",
      "Total training steps: 100\n",
      "*Note*: Total steps limited by max_steps setting (100)\n"
     ]
    }
   ],
   "source": [
    "# Calculate and print training steps information\n",
    "total_samples = len(train_dataloader.dataset)\n",
    "batch_size = trainer_config.batch_size\n",
    "steps_per_epoch = (total_samples + batch_size - 1) // batch_size\n",
    "total_steps = steps_per_epoch * trainer_config.num_epochs\n",
    "\n",
    "if trainer_config.max_steps:\n",
    "    total_steps = min(total_steps, trainer_config.max_steps)\n",
    "\n",
    "print(\"\\nTraining Configuration Summary:\")\n",
    "print(f\"Total samples: {total_samples}\")\n",
    "print(f\"Batch size: {batch_size}\")\n",
    "print(f\"Number of epochs: {trainer_config.num_epochs}\")\n",
    "print(f\"Steps per epoch: {steps_per_epoch}\")\n",
    "print(f\"Total training steps: {total_steps}\")\n",
    "if trainer_config.max_steps and total_steps == trainer_config.max_steps:\n",
    "    print(\n",
    "        f\"*Note*: Total steps limited by max_steps setting ({trainer_config.max_steps})\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading causal language model...\n"
     ]
    }
   ],
   "source": [
    "trainer = trainer_lib.CausalLMTrainer(\n",
    "    model=model,\n",
    "    model_ckpt_path=model_path,\n",
    "    model_configurator=model_configurator,\n",
    "    optimizer=optimizer,\n",
    "    training_config=trainer_config,\n",
    "    mesh=jax_utils.MESH,\n",
    "    model_name=MODEL_NAME,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start time: 1725751128.2981\n",
      "Starting epoch 0 of training...\n",
      "Epoch 0, Step 0, Train Loss: 3.2309, Accuracy: 0.3906\n",
      "Eval Step 0, Loss: 2.8140, Accuracy: 0.4531\n",
      "Evaluation complete. Average Loss: 2.8140, Average Accuracy: 0.4531\n",
      "Epoch 0, Step 0, Eval Loss: 2.8140, Accuracy: 0.4531\n",
      "Epoch 0, Step 5, Train Loss: 1.9765, Accuracy: 0.5938\n",
      "Epoch 0, Step 10, Train Loss: 1.5610, Accuracy: 0.6562\n",
      "Epoch 0, Step 15, Train Loss: 1.5088, Accuracy: 0.6719\n",
      "Epoch 0, Step 20, Train Loss: 1.1654, Accuracy: 0.7344\n",
      "Epoch 0, Step 25, Train Loss: 0.8645, Accuracy: 0.7969\n",
      "Epoch 0, Step 30, Train Loss: 0.9819, Accuracy: 0.8125\n",
      "Epoch 0, Step 35, Train Loss: 0.9629, Accuracy: 0.7656\n",
      "Epoch 0, Step 40, Train Loss: 0.6825, Accuracy: 0.8438\n",
      "Epoch 0, Step 45, Train Loss: 0.4973, Accuracy: 0.8438\n",
      "Epoch 0, Step 50, Train Loss: 0.5720, Accuracy: 0.8438\n",
      "Epoch 0, Step 55, Train Loss: 0.7773, Accuracy: 0.8438\n",
      "Epoch 0, Step 60, Train Loss: 0.4638, Accuracy: 0.9062\n",
      "Epoch 0, Step 65, Train Loss: 0.6748, Accuracy: 0.8438\n",
      "Epoch 0, Step 70, Train Loss: 0.5271, Accuracy: 0.8281\n",
      "Epoch 0, Step 75, Train Loss: 1.1770, Accuracy: 0.7500\n",
      "Epoch 0, Step 80, Train Loss: 0.7587, Accuracy: 0.7812\n",
      "Epoch 0, Step 85, Train Loss: 0.8808, Accuracy: 0.7969\n",
      "Epoch 0, Step 90, Train Loss: 1.1542, Accuracy: 0.7969\n",
      "Epoch 0, Step 95, Train Loss: 1.0281, Accuracy: 0.7969\n",
      "Epoch 0, Step 100, Train Loss: 0.8921, Accuracy: 0.7500\n",
      "End time: 1725752458.9299\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "start_time = time.time()\n",
    "print(f\"Start time: {start_time:.4f}\")\n",
    "\n",
    "state = trainer.train(train_dataloader, val_dataloader, run_jitted=True)\n",
    "\n",
    "end_time = time.time()\n",
    "print(f\"End time: {end_time:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Execution time: 1330.6318 seconds\n"
     ]
    }
   ],
   "source": [
    "# Calculate and print elapsed time\n",
    "elapsed_time = end_time - start_time\n",
    "print(f\"Execution time: {elapsed_time:.4f} seconds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Staging locations under the persistent disk.\n",
    "FELAFAX_DIR = \"/mnt/persistent-disk\"\n",
    "EXPORT_DIR = os.path.join(FELAFAX_DIR, \"export\")\n",
    "HF_EXPORT_DIR = os.path.join(FELAFAX_DIR, \"hf_export\")\n",
    "\n",
    "# Timestamped destination, so every run writes to its own directory.\n",
    "current_datetime = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "GCS_DIR = f\"/home/felafax-storage/checkpoints/{MODEL_NAME}/{current_datetime}/\"\n",
    "\n",
    "# Make sure all three directories exist before anything is written to them.\n",
    "for _directory in (EXPORT_DIR, HF_EXPORT_DIR, GCS_DIR):\n",
    "    utils.makedirs(_directory, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint to /mnt/persistent-disk/export/llama-3.1-8B-Instruct-JAX...\n",
      "Checkpoint saved to /mnt/persistent-disk/export/llama-3.1-8B-Instruct-JAX.\n"
     ]
    }
   ],
   "source": [
    "flax_checkpoint_path = os.path.join(EXPORT_DIR, MODEL_NAME)\n",
    "trainer.save_checkpoint(state, path=flax_checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the checkpoint in a Llama model.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 33/33 [00:20<00:00,  1.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving in the Transformers format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00,  8.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Copied README.md to /mnt/persistent-disk/hf_export\n",
      "Copied tokenizer.json to /mnt/persistent-disk/hf_export\n",
      "Copied .gitattributes to /mnt/persistent-disk/hf_export\n",
      "Copied special_tokens_map.json to /mnt/persistent-disk/hf_export\n",
      "Copied tokenizer_config.json to /mnt/persistent-disk/hf_export\n",
      "Copied generation_config.json to /mnt/persistent-disk/hf_export\n",
      "Copied config.json to /mnt/persistent-disk/hf_export\n",
      "All tokenizer files saved to /mnt/persistent-disk/hf_export\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "convert_lib.save_hf_compatible_checkpoint(\n",
    "    f'flax_params::{flax_checkpoint_path}', HF_EXPORT_DIR, model_configurator)\n",
    "\n",
    "# Download and save the tokenizer\n",
    "tokenizer_repo = f\"felafax/tokenizer-llama-3.1-8B-Instruct-JAX\"\n",
    "tokenizer_dir = snapshot_download(repo_id=tokenizer_repo)\n",
    "\n",
    "# Move all files from tokenizer_dir to HF_EXPORT_DIR\n",
    "for item in os.listdir(tokenizer_dir):\n",
    "    s = os.path.join(tokenizer_dir, item)\n",
    "    d = os.path.join(HF_EXPORT_DIR, item)\n",
    "    if os.path.isfile(s):\n",
    "        shutil.copy2(s, d)\n",
    "        print(f\"Copied {item} to {HF_EXPORT_DIR}\")\n",
    "    elif os.path.isdir(s):\n",
    "        shutil.copytree(s, d, dirs_exist_ok=True)\n",
    "        print(f\"Copied directory {item} to {HF_EXPORT_DIR}\")\n",
    "print(f\"All tokenizer files saved to {HF_EXPORT_DIR}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting to copy directory /mnt/persistent-disk/hf_export to /home/felafax-storage/checkpoints/llama-3.1-8B-Instruct-JAX/20240907_230836/\n",
      "Checkpoint copied to /home/felafax-storage/checkpoints/llama-3.1-8B-Instruct-JAX/20240907_230836/\n"
     ]
    }
   ],
   "source": [
    "checkpoint_lib.copy_directory(HF_EXPORT_DIR, GCS_DIR)\n",
    "print(f\"Checkpoint copied to {GCS_DIR}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HUGGINGFACE_TOKEN = input(\"INPUT: Please provide your HUGGINGFACE_TOKEN: \")\n",
    "# HUGGINGFACE_USERNAME = input(\n",
    "#     \"INPUT: Please provide your HUGGINGFACE_USERNAME: \")\n",
    "# HUGGINGFACE_REPO_NAME = input(\n",
    "#     \"INPUT: Please provide your HUGGINGFACE_REPO_NAME: \")\n",
    "# convert_lib.upload_checkpoint_to_hf(\n",
    "#     HF_EXPORT_DIR, f\"{HUGGINGFACE_USERNAME}/{HUGGINGFACE_REPO_NAME}\",\n",
    "#     HUGGINGFACE_TOKEN)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
